## Exact C1 via single integral
# Prior probability C1 = P(p1 - p2 > theta_star) when, a priori,
# p1 ~ Beta(a1, b1) and p2 ~ Beta(a2, b2) independently:
#   C1 = int dbeta(p1; a1, b1) * P(p2 < p1 - theta_star) dp1.
# The integrand is zero for p1 <= max(0, theta_star), so integration starts
# there. Starting at 0 rather than at a negative theta_star keeps a possible
# density singularity of Beta(a1 < 1, .) at an endpoint of the range, where
# integrate() handles it best.
compute_C1 <- function(a1, b1, a2, b2, theta_star) {
  lower <- max(0, theta_star)
  if (lower >= 1) {
    # p1 - p2 can never exceed a threshold >= 1
    return(0)
  }
  integrand <- function(p1) {
    dbeta(p1, a1, b1) * pbeta(p1 - theta_star, a2, b2)
  }
  integrate(integrand, lower = lower, upper = 1)$value
}

## Calibration helper (works for any q)
# Re-weights a posterior probability post_p of H1 by the prior odds:
# H1 gets weight q / C1 and H0 gets weight (1 - q) / C0, giving
#   w1 * post_p / (w0 + (w1 - w0) * post_p).
# Vectorised over post_p.
calibrate <- function(post_p, C1, C0, q) {
  w0 <- (1 - q) / C0
  w1 <- q / C1
  numer <- w1 * post_p
  denom <- w0 + (w1 - w0) * post_p
  numer / denom
}

# -------------------------------------------------------------------------
# Posterior P(H1) for two-arm binary with Monte Carlo ---------------------
# -------------------------------------------------------------------------
# Y_t / Y_c successes out of n in the treatment / control arm. Draws `sim`
# samples from each Beta posterior, estimates P(p_t - p_c > theta_star) by
# the sample proportion, then calibrates it with the prior mass C1 of H1
# (compute_C1) and C0 = 1 - C1.
post_prob_H1_bin <- function(Y_t, Y_c, n, theta_star,
                             a1 = 0.5, b1 = 0.5, a2 = 0.5, b2 = 0.5,
                             q  = 0.5, sim = 1e4) {
  # Treatment draws come first, control second (keeps seeded runs stable).
  draws_t <- rbeta(sim, shape1 = a1 + Y_t, shape2 = b1 + n - Y_t)
  draws_c <- rbeta(sim, shape1 = a2 + Y_c, shape2 = b2 + n - Y_c)
  raw_prob <- mean(draws_t - draws_c > theta_star)

  # prior mass of H1 (and, by complement, H0)
  prior_h1 <- compute_C1(a1, b1, a2, b2, theta_star)
  calibrate(raw_prob, C1 = prior_h1, C0 = 1 - prior_h1, q = q)
}

## Compute equation (8): the smallest calibrated posterior P(H1), xi, over all
## outcome pairs (y1, y0) with y1 - y0 = yd and 0 <= y0, y1 <= n.
## Returns list(y1, y0, min_xi) for the minimising pair (first one on ties).
## Errors when |yd| > n, because then no such pair exists.
find_min_xi <- function(yd, n, theta_star, a1 = 0.5, b1 = 0.5,
                        a2 = 0.5, b2 = 0.5, q  = 0.5, sim = 1e4) {

  # 1. valid control counts y0 such that y1 = yd + y0 also lies in 0:n
  y0_min <- max(0, -yd)
  y0_max <- min(n, n - yd)
  if (y0_min > y0_max) {
    # seq() would otherwise count *down* and yield out-of-range counts
    stop("find_min_xi: no feasible (y1, y0) pair for yd = ", yd,
         " and n = ", n, "; need |yd| <= n.", call. = FALSE)
  }
  y0_seq <- seq(y0_min, y0_max)

  # 2. corresponding treatment counts
  y1_seq <- yd + y0_seq

  # 3. calibrated posterior P(H1) for every feasible pair
  xi_vec <- vapply(seq_along(y1_seq), function(i) {
    post_prob_H1_bin(Y_t = y1_seq[i], Y_c = y0_seq[i], n = n,
                     theta_star = theta_star, a1 = a1, b1 = b1,
                     a2 = a2, b2 = b2, q = q, sim = sim)
  }, numeric(1))

  # 4. pick the minimal xi
  i_min <- which.min(xi_vec)
  list(y1 = y1_seq[i_min], y0 = y0_seq[i_min], min_xi = xi_vec[i_min])
}

## Helper: log g(theta0), where
##   g = (theta0 + theta_star)^(y_bar0 + e) * (1 - theta0 - theta_star)^(1 - y_bar0 - e)
##       * theta0^y_bar0 * (1 - theta0)^(1 - y_bar0).
## Its theta0-derivative is dlogg() inside find_theta0(). All four terms must
## depend on theta0; using y_bar0 in place of theta0 in the first two terms
## would make them constant and shift the maximiser.
log_g <- function(theta0, y_bar0, e, theta_star){
  (y_bar0 + e)       * log(theta0 + theta_star) +
    (1 - y_bar0 - e) * log(1 - theta0 - theta_star) +
    y_bar0           * log(theta0) +
    (1 - y_bar0)     * log(1 - theta0)
}

## Find theta0* = argmax log g(theta0), i.e. the root of d/dtheta0 log g, with
##   log g = (y_bar0 + e) log(theta0 + theta_star)
##           + (1 - y_bar0 - e) log(1 - theta0 - theta_star)
##           + y_bar0 log(theta0) + (1 - y_bar0) log(1 - theta0).
## The search runs over the open set where every log argument is positive:
##   max(0, -theta_star) < theta0 < min(1, 1 - theta_star)
## shrunk by 1e-12 at each end. For theta_star >= 0 this is (0, 1 - theta_star).
## tol and maxiter go to uniroot(); tol also goes to the optimise() fallback.
find_theta0 <- function(y_bar0, e, theta_star, tol = 1e-10, maxiter = 1000){
  # ---- log g and its derivative -------------------------------------------
  # log g is defined locally so the fallback does not rely on any other helper.
  logg <- function(theta0){
    (y_bar0 + e) * log(theta0 + theta_star) +
      (1 - y_bar0 - e) * log(1 - theta0 - theta_star) +
      y_bar0 * log(theta0) +
      (1 - y_bar0) * log(1 - theta0)
  }
  dlogg <- function(theta0){
    (y_bar0 + e)/(theta0 + theta_star) -
      (1 - y_bar0 - e)/(1 - theta0 - theta_star) + y_bar0/theta0 -
      (1 - y_bar0)/(1 - theta0)
  }

  # ---- search interval: strictly inside the domain of log g ---------------
  eps <- 1e-12
  lower <- max(0, -theta_star) + eps
  upper <- min(1, 1 - theta_star) - eps
  if (!(lower < upper)) {
    stop("find_theta0: empty search interval; need -1 < theta_star < 1.",
         call. = FALSE)
  }

  # ---- try uniroot first ---------------------------------------------------
  f_lower <- dlogg(lower)
  f_upper <- dlogg(upper)
  # isTRUE(): a NaN end-point value falls through to the fallback
  if (isTRUE(sign(f_lower) * sign(f_upper) < 0)) {
    root <- uniroot(dlogg, interval = c(lower, upper),
                    f.lower = f_lower, f.upper = f_upper,
                    tol = tol, maxiter = maxiter)$root
    return(root)
  }

  # ---- no bracket: maximise log g directly --------------------------------
  # log g is concave here, so its maximiser is the stationary point (or the
  # end of the interval when dlogg does not change sign).
  opt <- optimise(logg, interval = c(lower, upper), tol = tol,
                  maximum = TRUE)
  opt$maximum
}

## Helper function to evaluate the cal{B} term at the optimal theta0, so the
## caller can check whether it is >= 0 at sample size n:
##   B = B(c3) * B(c4) / (B(c1) * B(c2)) - g(theta0*)
## where B() is the Beta function with shapes after n (c1, c2) and n + 1
## (c3, c4) observations at success proportions y_bar0 + e (arm 1) and
## y_bar0 (arm 2), and theta0* comes from find_theta0().
n_max_gt_0 <- function(e, y_bar0, theta_star, a1, b1, a2, b2, n){
  # lbeta() instead of log(beta()): beta() underflows to 0 for moderately
  # large n, which would make the log ratio below -Inf - (-Inf) = NaN.
  lc1 <- lbeta(a1 + n * (y_bar0 + e), b1 + n * (1 - y_bar0 - e))
  lc2 <- lbeta(a2 + n * y_bar0, b2 + n * (1 - y_bar0))
  lc3 <- lbeta(a1 + (n + 1) * (y_bar0 + e), b1 + (n + 1) * (1 - y_bar0 - e))
  lc4 <- lbeta(a2 + (n + 1) * y_bar0, b2 + (n + 1) * (1 - y_bar0))

  opt_theta0 <- find_theta0(y_bar0, e, theta_star)

  log_term1 <- lc3 + lc4 - lc1 - lc2
  # log g evaluated at theta0*
  log_term2 <- (y_bar0 + e) * log(opt_theta0 + theta_star) +
    (1 - y_bar0 - e) * log(1 - opt_theta0 - theta_star) +
    y_bar0 * log(opt_theta0) +
    (1 - y_bar0) * log(1 - opt_theta0)

  exp(log_term1) - exp(log_term2)
}

## Step function for integration
# With x ~ Beta(a1 + n (y_bar0 + e), b1 + n (1 - y_bar0 - e)) and
# y ~ Beta(a2 + n y_bar0, b2 + n (1 - y_bar0)) independent, returns
#   int_0^1 dbeta(x) * P(y > max(0, x - theta_star)) dx.
func_step <- function(e, y_bar0, theta_star, a1, b1, a2, b2, n){
  shape_x <- c(a1 + n * (y_bar0 + e), b1 + n * (1 - y_bar0 - e))
  shape_y <- c(a2 + n * y_bar0, b2 + n * (1 - y_bar0))

  weighted_tail <- function(x) {
    cutoff <- pmax(0, x - theta_star)
    surv_y <- 1 - pbeta(cutoff, shape_y[1], shape_y[2])
    dbeta(x, shape_x[1], shape_x[2]) * surv_y
  }

  integrate(weighted_tail, lower = 0, upper = 1,
            subdivisions = 1000, rel.tol = 1e-8)$value
}

# -------------------------------------------------------------------------
# Find n_min for a fixed y_bar0 -------------------------------------------
# -------------------------------------------------------------------------
# Returns:
#   "Please increase n_max."  if n_max_gt_0() at n_max is negative;
#   -1                        if there is no n in 1:n_max such that
#                             func_step(m) - func_step(m + 1) > 0 for every
#                             m in n:n_max;
#   otherwise the smallest such n.
find_n_min <- function(e, y_bar0, theta_star, a1, b1, a2, b2, n_max){

  # B-term check at n_max; a negative value means n_max is too small.
  n_max_val <- n_max_gt_0(e, y_bar0, theta_star, a1, b1, a2, b2, n_max)
  if (is.na(n_max_val)) {
    stop("find_n_min: n_max_gt_0() returned NA/NaN for y_bar0 = ", y_bar0,
         ".", call. = FALSE)
  }
  if (n_max_val < 0) {
    return("Please increase n_max.")
  }

  # func_step() at 1, ..., n_max + 1; each n is evaluated once and reused for
  # both differences it takes part in.
  steps <- vapply(seq_len(n_max + 1),
                  function(n) func_step(e, y_bar0, theta_star,
                                        a1, b1, a2, b2, n),
                  numeric(1))
  idx_n <- seq_len(n_max)
  all_val <- steps[idx_n] - steps[idx_n + 1]

  # right_min[n] = min(all_val[n:n_max]); its first positive entry is the
  # first n from which every later difference is positive.
  right_min <- rev(cummin(rev(all_val)))
  idx <- which(right_min > 0)[1]
  if (is.na(idx)) {
    # covers max(all_val) <= 0 and a non-positive final difference alike
    return(-1)
  }
  idx
}

# -------------------------------------------------------------------------
# BESS Algorithm 4: Compute the n_min for two-arm trials ------------------
# -------------------------------------------------------------------------
# Parameters:
#   e: evidence; theta_star: clinically minimum effect size; 
#   n_max: maximum candidate sample size;
#   a1, b1, a2, b2: hyperparameters;
# Evaluates find_n_min() at a set of candidate control means y_bar0 (the
# boundaries 0 and 1 - e plus interior points) and returns the largest n_min,
# or "Please increase n_max." if any candidate has no valid n_min.
# Candidates outside [max(0, -e), min(1, 1 - e)] are dropped: there y_bar0 or
# y_bar0 + e leaves [0, 1] and the Beta shapes can turn negative.
find_n_min_BESS2 <- function(e, theta_star, n_max, a1 = 0.5, 
                             b1 = 0.5, a2 = 0.5, b2 = 0.5){
  msg <- "Please increase n_max."

  bd_vals <- c(0, 1 - e)
  if (a1 == b1 && a2 == b2 && a1 == a2) {
    bd_vals <- c(bd_vals, (1 - e) / 2)
  } else {
    bd_vals <- c(bd_vals, (1 - e) / 2 + (b1 - a1) / (2 * seq_len(n_max)))
  }
  bd_vals <- unique(bd_vals)
  bd_vals <- bd_vals[bd_vals >= max(0, -e) & bd_vals <= min(1, 1 - e)]
  if (length(bd_vals) == 0) {
    stop("find_n_min_BESS2: no feasible y_bar0 candidate; need -1 <= e <= 1.",
         call. = FALSE)
  }

  vals <- numeric(length(bd_vals))
  for (i in seq_along(bd_vals)) {
    n_min <- find_n_min(e, bd_vals[i], theta_star, a1, b1, a2, b2, n_max)
    if (is.character(n_min)) {
      return(n_min)
    }
    vals[i] <- n_min
  }

  # any -1 (or NA) means some candidate never reached a monotone region
  if (all(!is.na(vals) & vals > 0)) {
    max(vals)
  } else {
    msg
  }
}

# ------------------------------------------------------------------------
# Unit-Testing BESS Algo 3 -----------------------------------------------
# ------------------------------------------------------------------------

#set.seed(12345)
#theta_star <- 0.05
#e <- 0.1
#find_n_min_BESS2(e, theta_star, 500)

## Helper: "or fallback" for NULL or zero-length values.
## Returns x unless it is NULL or empty, in which case y is returned.
## Note that NA has length 1 and is therefore returned as-is.
`%||%` <- function(x, y) {
  if (!is.null(x) && length(x) != 0) x else y
}

# -------------------------------------------------------------------------
# BESS Algorithm 2: Find sample size for two-arm trials -------------------
# -------------------------------------------------------------------------
# Parameters:
#   theta_star: clinically minimum effect size;
#   c: confidence level; e: evidence; 
#   a1 = b1 = a2 = b2 = 0.5: hyperparameters;
#   n_min, n_max: minimum and maximum candidate sample sizes (n_min <= n_max);
#   q = 0.5: prior prob of H_1; sim = 10000: simulation number;
#   method = {"BESS2", "BESS2p"}: BESS 2 or BESS 2';
#   y_bar0 = NA: the reference value for the control arm. Must be a non-NA
#     number when method = "BESS2p"; unused by "BESS2".
# Returns a list(n, yd, n_seq, xi_vec, yd_vec, y1_vec, y0_vec): n is the first
# n in n_min:n_max with xi >= c, or n_max when xi < c already at n_max or no n
# reaches c. For "BESS2p" without y_bar0 a message string is returned instead.
BESS_binary <- function(theta_star, c, e, n_min, n_max, 
                        a1 = 0.5, b1 = 0.5, a2 = 0.5, b2 = 0.5, 
                        q = 0.5, sim = 10000, method = c("BESS2", "BESS2p"), 
                        y_bar0 = NA){

  method <- match.arg(method)
  stopifnot(n_min <= n_max)
  n_seq <- seq(n_min, n_max)

  # First position with xi >= c, else the last position (n_max). which()[1]
  # is NA (not NULL) when nothing matches, so the fallback is explicit.
  first_hit <- function(xi) {
    hit <- which(xi >= c)
    if (length(hit) > 0) hit[1] else length(n_seq)
  }

  if (method == "BESS2p") {
    if (is.null(y_bar0) || length(y_bar0) != 1 || is.na(y_bar0)) {
      return("`y_bar0` must be provided for method = 'BESS2p'.")
    }
    # Counts per n: y0 = round(y_bar0 * n), y1 = y0 + floor(e * n). The same
    # rule is used for the n_max check and for the full sequence.
    y0_vec <- round(y_bar0 * n_seq)
    y1_vec <- y0_vec + floor(e * n_seq)
    y0 <- round(y_bar0 * n_max)
    y1 <- y0 + floor(e * n_max)

    # If c is not attained even at n_max, report n_max.
    xi_max <- post_prob_H1_bin(y1, y0, n_max, theta_star, 
                               a1 = a1, b1 = b1, a2 = a2, b2 = b2, 
                               q = q, sim = sim)
    if (xi_max < c) {
      return(list(n = n_max, yd = y1 - y0, n_seq = n_seq,
                  xi_vec = xi_max, yd_vec = y1 - y0,
                  y1_vec = y1, y0_vec = y0))
    }

    xi_vec <- vapply(seq_along(n_seq), function(i) {
      post_prob_H1_bin(Y_t = y1_vec[i], Y_c = y0_vec[i], n = n_seq[i],
                       theta_star = theta_star, a1 = a1, b1 = b1,
                       a2 = a2, b2 = b2, q = q, sim = sim)
    }, numeric(1))

    n_hit <- first_hit(xi_vec)
    return(list(n = n_seq[n_hit], yd = y1_vec[n_hit] - y0_vec[n_hit], 
                n_seq = n_seq, xi_vec = xi_vec, yd_vec = y1_vec - y0_vec,
                y1_vec = y1_vec, y0_vec = y0_vec))
  }

  # method == "BESS2": for each n, minimise xi over all (y1, y0) pairs with
  # y1 - y0 = floor(e * n).
  yd_max <- floor(e * n_max)
  min_xi_max <- find_min_xi(yd_max, n_max, theta_star, a1 = a1, b1 = b1, 
                            a2 = a2, b2 = b2, q = q, sim = sim)$min_xi
  if (min_xi_max < c) {
    return(list(n = n_max, yd = yd_max, n_seq = n_seq, 
                xi_vec = numeric(0), yd_vec = numeric(0), 
                y1_vec = numeric(0), y0_vec = numeric(0)))
  }

  mat <- do.call(rbind, lapply(n_seq, function(n) {
    yd_val <- floor(e * n)
    res <- find_min_xi(yd = yd_val, n = n, theta_star = theta_star,
                       a1 = a1, b1 = b1, a2 = a2, b2 = b2, q = q, sim = sim)
    c(xi = res$min_xi, yd = yd_val, y1 = res$y1, y0 = res$y0)
  }))

  n_hit <- first_hit(mat[, "xi"])
  list(n = n_seq[n_hit], yd = mat[n_hit, "yd"], n_seq = n_seq, 
       xi_vec = mat[, "xi"], yd_vec = mat[, "yd"],
       y1_vec = mat[, "y1"], y0_vec = mat[, "y0"])
}

# ------------------------------------------------------------------------
# Unit-Testing Two-Arm Binary --------------------------------------------
# ------------------------------------------------------------------------

# Testing functions
#set.seed(12345)
#theta_star <- 0.05
#c <- 0.8
#e <- 0.1
#BESS_binary(theta_star, c, e, 1, 150)

# -------------------------------------------------------------------------
# BESS Design -------------------------------------------------------------
# -------------------------------------------------------------------------
# Simulates one two-stage trial. Stage 1 enrolls n0 per arm and computes the
# calibrated posterior P(H1) with threshold -theta_star. Only if the interim
# is inconclusive, the stage-2 size comes from BESS_binary() with the stage-1
# posteriors as priors, and a second analysis is run on the stage-2 data.
# Parameters:
#   theta_H, theta_L: true high and low dose response rates;
#   theta_star: clinically minimum effect size;
#   c: confidence level for success; c_s: confidence level for futility; 
#   n0: interim sample size; a1 = b1 = a2 = b2 = 0.5: hyperparameters;
#   q = 0.5: prior probability of H_1;
#   n_min, n_max: minimum and maximum candidate sample sizes; 
# Returns a list; bess_int_dec is 1 (stop for success), 0 (stop for futility)
# or 2 (continue). The stage-2 fields are NA unless bess_int_dec == 2.
BESS_SSR_design <- function(theta_H, theta_L, theta_star, c, c_s, n0, 
                            a1 = 0.5, b1 = 0.5, a2 = 0.5, b2 = 0.5, q = 0.5,
                            n_min = 1, n_max = 200) {
  # ---- Stage 1 -------------------------------------------------------------
  # H is drawn before L; keep this order so seeded runs reproduce.
  y_H <- rbinom(n0, 1, theta_H)
  y_L <- rbinom(n0, 1, theta_L)
  succ_H <- sum(y_H)
  succ_L <- sum(y_L)

  pH1 <- post_prob_H1_bin(succ_L, succ_H, n0, -theta_star,
                          a1 = a1, b1 = b1, a2 = a2, b2 = b2, q = q)
  int_dec <- if (pH1 >= c) 1 else if (pH1 <= c_s) 0 else 2

  y_H_s <- y_L_s <- e_s <- n_s <- pH1_s <- fin_dec <- NA

  if (int_dec == 2) {
    # ---- Stage 2: stage-1 posteriors become the priors --------------------
    # (arm 1 = low dose L, arm 2 = high dose H)
    a1_post <- a1 + succ_L
    b1_post <- b1 + n0 - succ_L
    a2_post <- a2 + succ_H
    b2_post <- b2 + n0 - succ_H

    # evidence = posterior mean of L minus posterior mean of H
    e_s <- a1_post / (a1_post + b1_post) - a2_post / (a2_post + b2_post)

    n_s <- BESS_binary(-theta_star, c, e_s, n_min, n_max,
                       a1 = a1_post, b1 = b1_post, a2 = a2_post, b2 = b2_post,
                       q = q, method = "BESS2")[["n"]]

    y_H_s <- rbinom(n_s, 1, theta_H)
    y_L_s <- rbinom(n_s, 1, theta_L)

    pH1_s <- post_prob_H1_bin(sum(y_L_s), sum(y_H_s), n_s, -theta_star,
                              a1 = a1_post, b1 = b1_post,
                              a2 = a2_post, b2 = b2_post, q = q)
    fin_dec <- if (pH1_s >= c) 1 else 0
  }

  list(y_H_n0 = y_H, y_L_n0 = y_L, pH1_int = pH1,
       bess_int_dec = int_dec, y_H_ns = y_H_s,
       y_L_ns = y_L_s, e_s = e_s, n_s = n_s,
       pH1_s = pH1_s, bess_fin_dec = fin_dec)
}

# -------------------------------------------------------------------------
# Standard SSE and Standard SSE with Interim ------------------------------
# -------------------------------------------------------------------------
# Simulates one trial with n_f subjects per arm and returns the numeric vector
#   c(mean(y_H), mean(y_L), decision (1 = success, 0 = not), total n used).
# Final analysis: one-sided Wald test with
#   z = (mean(y_L) - mean(y_H) + theta_star) /
#       sqrt((mean(y_L)(1 - mean(y_L)) + mean(y_H)(1 - mean(y_H))) / n_f),
# declaring success when P(Z > z) < alpha. A degenerate statistic (0/0, e.g.
# both arms all failures with theta_star = 0) counts as "no success".
# Parameters:
#   n_f: Frequentist sample size;
#   theta_H, theta_L: true high and low dose response rates;
#   theta_star: clinically minimum effect size; alpha: Type I error rate;
#   interim: TRUE or FALSE; TRUE adds a Bayesian interim look on the first
#     n_0 subjects per arm (success if P(H1) >= c, futility if P(H1) <= c_s);
#   c: confidence level for success; c_s: confidence level for futility;
#   n_0: interim sample size, required (1 <= n_0 <= n_f) when interim is TRUE;
#   a1 = b1 = a2 = b2 = 0.05: hyperparameters (NOTE(review): these defaults
#     differ from the 0.5 used by the other designs -- confirm intended);
#   q = 0.5: prior probability of H_1.
StandardSSE_bin_design <- function(n_f, theta_H, theta_L, theta_star,  
                                   alpha, interim = FALSE, n_0 = NA,  
                                   a1 = 0.05, b1 = 0.05, a2 = 0.05, b2 = 0.05,
                                   q = 0.5, c = 0.7, c_s = 0.3){

  y_H <- rbinom(n_f, 1, theta_H)
  y_L <- rbinom(n_f, 1, theta_L)

  # Wald test on all n_f subjects per arm; 1 = reject, 0 = do not reject.
  wald_decision <- function() {
    m_L <- mean(y_L)
    m_H <- mean(y_H)
    z_stat <- (m_L - m_H + theta_star) /
      (sqrt((m_L * (1 - m_L) + m_H * (1 - m_H)) / n_f))
    p_val <- pnorm(z_stat, lower.tail = FALSE)
    # isTRUE(): a NaN p-value (0/0 statistic) is "not significant"
    if (isTRUE(p_val < alpha)) 1 else 0
  }

  if (!interim) {
    return(c(mean(y_H), mean(y_L), wald_decision(), n_f))
  }

  if (is.na(n_0)) {
    stop("When interim is True, there has to be an interim sample size n_0 > 0.")
  }
  if (n_0 < 1 || n_0 > n_f) {
    stop("n_0 must satisfy 1 <= n_0 <= n_f (got n_0 = ", n_0,
         ", n_f = ", n_f, ").")
  }

  # Interim look on the first n_0 subjects of each arm
  first_n0 <- seq_len(n_0)
  pH1 <- post_prob_H1_bin(sum(y_L[first_n0]), sum(y_H[first_n0]), n_0,
                          -theta_star, a1 = a1, b1 = b1, a2 = a2, b2 = b2,
                          q = q)
  if (pH1 >= c) {
    # stop for success
    return(c(mean(y_H), mean(y_L), 1, n_0))
  }
  if (pH1 <= c_s) {
    # stop for futility
    return(c(mean(y_H), mean(y_L), 0, n_0))
  }
  c(mean(y_H), mean(y_L), wald_decision(), n_f)
}


